Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix static generation when compiling! #28937

Merged
merged 42 commits into from
Feb 15, 2024
Merged

Fix static generation when compiling! #28937

merged 42 commits into from
Feb 15, 2024

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented Feb 9, 2024

What does this PR do?

Fixes the static cache generation. Comes with #27931

thanks @OlivierDehaene for the insight

https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03 benchmark

  • fixes issue with FlashAttention: when the cache is padded you need the full attention mask, otherwise generations will be wrong with generate because the first forward will be fully causal.
  • fixes graph runs: the cache positions have to be stateless, they are otherwise ignored by the model and the compiled generation are random
  • fixes potential BC by guarding the use of cache positions

FA2 potential fix if compiled worked:

            # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails
            if (cache_position is not None):  
                key_states = key_states[:, :, : cache_position[-1] + 1, :]
                value_states = value_states[:, :, : cache_position[-1] + 1, :]

but I have slowdowns:
Slicing
image
vs no Slicing
image

@ArthurZucker ArthurZucker changed the title wow I was scared! Fix static generation when compiling! Feb 9, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sanchit-gandhi
Copy link
Contributor

sanchit-gandhi commented Feb 9, 2024

I'm not sure adding a new argument cache_position to the forward call of the model is strictly backwards compatible. Here's an example to motivate this.

The following works on transformers==4.37.2:

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM", attn_implementation="eager")
tokenizer = LlamaTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")

# random input id
inputs = tokenizer("Hey there", return_tensors="pt", return_attention_mask=True)

position_ids = inputs.attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(inputs.attention_mask == 0, 1)

with torch.no_grad():
    logits = model.forward(**inputs, position_ids=position_ids).logits

If we run the same code on this PR, we get the following error:

  File "/Users/sanchitgandhi/transformers/src/transformers/models/llama/modeling_llama.py", line 352, in forward
    attn_weights = attn_weights + causal_mask
                   ~~~~~~~~~~~~~^~~~~~~~~~~~~
RuntimeError: The size of tensor a (3) must match the size of tensor b (2048) at non-singleton dimension 4
Full traceback:
  File "/Users/sanchitgandhi/transformers/debug_llama.py", line 14, in <module>
    logits = model.forward(**inputs, position_ids=position_ids).logits
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/transformers/src/transformers/models/llama/modeling_llama.py", line 1106, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/Users/sanchitgandhi/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/transformers/src/transformers/models/llama/modeling_llama.py", line 950, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/transformers/src/transformers/models/llama/modeling_llama.py", line 694, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/transformers/src/transformers/models/llama/modeling_llama.py", line 352, in forward
    attn_weights = attn_weights + causal_mask
                   ~~~~~~~~~~~~~^~~~~~~~~~~~~
RuntimeError: The size of tensor a (3) must match the size of tensor b (2048) at non-singleton dimension 4

This is because cache_positions is not specified to the forward call, and so defaults to None. When we do our reshape in the attention layer:

causal_mask = attention_mask[ :, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

instead of reshaping to [ :, :, cache_position, : key_states.shape[-2]], we reshape to [ :, :, None, : key_states.shape[-2]]. So instead of slicing, we insert an extra dimension! This gives the size mismatch when we add the attention mask to the weights. The user needs to specify cache_position as an argument to the forward call in order for this to work.

Overall, I think we should avoid adding extra arguments that require code changes from the user, especially to the top-level modules which are already highly-used. What about a design more like Flax where we keep track of the cache_position internally in the StaticCache abstraction? This then requires no changes from the user

@ArthurZucker
Copy link
Collaborator Author

We can make it BC! this PR is not ready yet, but generate should check the past key value class and if signature can take cache_position, give them. Something like that.

I'll work on making it BC! :)

Comment on lines 924 to 930
past_seen_tokens = 0
if use_cache and not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_usable_length(inputs_embeds.shape[1]) # kept for BC (cache positions)

if cache_position is None:
cache_position = torch.arange(past_seen_tokens, past_seen_tokens+inputs_embeds.shape[1])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has to be kept for BC

Comment on lines 1054 to 1060
if attention_mask is None:
return None
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and (torch.all(attention_mask == 1)):
return None
if is_tracing and seq_length == 1:
return None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all of this failed generations, deal with it later

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @fxmarty I am warning you in advance 🥶 you might have to do something similar to the prepared_4d_sdpa but this is a lot simpler so for the better

Comment on lines 1209 to 1213
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
cache_position = torch.arange(past_length, past_length+input_ids.shape[1])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kept for BC as well, generate should handle cache positions IMO

@ArthurZucker ArthurZucker marked this pull request as ready for review February 12, 2024 08:08
Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-approving, as the overall PR shape looks good to me 👍

(btw, this PR is blocking further work on generate, as llama + generate + dynamic cache is not correct at the moment and I want to standardize the interface of the different cache classes to match the static cache)

@ArthurZucker
Copy link
Collaborator Author

Thanks, merging asap

@ArthurZucker
Copy link
Collaborator Author

image

Slow tests are happy

Comment on lines -4779 to +4781
bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
keys_to_ignore = ["cache_position", "encoder_outputs"]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

beam search will split the cache positions otherwise

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the huge work ! I left some minor comments that should be addressed before merging IMO, otherwise we might introduce some breaking change for users that use our public classes without explicit positional arguments

src/transformers/cache_utils.py Show resolved Hide resolved
src/transformers/cache_utils.py Show resolved Hide resolved
src/transformers/models/llama/modeling_llama.py Outdated Show resolved Hide resolved
src/transformers/models/llama/modeling_llama.py Outdated Show resolved Hide resolved
src/transformers/models/llama/modeling_llama.py Outdated Show resolved Hide resolved
src/transformers/models/llama/modeling_llama.py Outdated Show resolved Hide resolved
@younesbelkada
Copy link
Contributor

Example of a breaking behaviour that I introduced while working on FA2: #25598 (comment) so we should be careful when adding new args in our modules

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much !

@ArthurZucker ArthurZucker merged commit f3788b0 into main Feb 15, 2024
21 checks passed
@ArthurZucker ArthurZucker deleted the fix-static-kv-cache branch February 15, 2024 05:27
hackyon pushed a commit to hackyon/transformers that referenced this pull request Feb 15, 2024
* wow I was scared!

* fix everything

* nits

* make it BC?

* add todo

* nits

* is_tracing should still be used to pass tracing tests

* nits

* some nits to make sure genration works with static cache uncompiled

* fix sdpa

* fix FA2 for both static and dynamic in a better way?

* style

* fix-copies

* fix fix copies

* fix sequential beam searcg

* style

* use `keys_to_ignore`

* nit

* correct dtype inference when init

* :( the fix for FA2 is still not optimal to investigate!

* styling

* nits

* nit

* this might work better

* add comment

* Update src/transformers/models/llama/modeling_llama.py

* "position_ids" -> "cache_position"

* style

* nit

* Remove changes that should no be propagatted just yet

* Apply suggestions from code review

* Styling

* make sure we raise an errir for static cache with FA2 enabled

* move  to the bottom of the signature

* style

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/llama/modeling_llama.py

* nit in the name

---------

Co-authored-by: Younes Belkada <[email protected]>
@ArthurZucker ArthurZucker mentioned this pull request Feb 20, 2024
@alanwaketan
Copy link
Contributor

Hey @ArthurZucker, I discovered that this change actually breaks TPU...

Now, TPU training with FSDPv2 will produce loss with NaN. I haven't looked into your PR so I'm not sure why. Just bisecting til this change.

@ArthurZucker
Copy link
Collaborator Author

Mmm this might be a ROPE issue? #29109 might also play

@learning-chip
Copy link

Hi @ArthurZucker I run your benchmark script with both transformers 4.38.0 and 4.38.2 but got error:

Traceback (most recent call last):
  File "/home/best_benchmark.py", line 99, in <module>
    generated_ids[:, cache_position] = input_ids.to(device).to(torch.int)
RuntimeError: shape mismatch: value tensor of shape [1686] cannot be broadcast to indexing result of shape [1, 2048]

@ArthurZucker
Copy link
Collaborator Author

It is probably out of date! I'll update it

@ArthurZucker
Copy link
Collaborator Author

We'll actually push a full benchmark in transformers to make sur we always track this!

itazap pushed a commit that referenced this pull request May 14, 2024
* wow I was scared!

* fix everything

* nits

* make it BC?

* add todo

* nits

* is_tracing should still be used to pass tracing tests

* nits

* some nits to make sure genration works with static cache uncompiled

* fix sdpa

* fix FA2 for both static and dynamic in a better way?

* style

* fix-copies

* fix fix copies

* fix sequential beam searcg

* style

* use `keys_to_ignore`

* nit

* correct dtype inference when init

* :( the fix for FA2 is still not optimal to investigate!

* styling

* nits

* nit

* this might work better

* add comment

* Update src/transformers/models/llama/modeling_llama.py

* "position_ids" -> "cache_position"

* style

* nit

* Remove changes that should no be propagatted just yet

* Apply suggestions from code review

* Styling

* make sure we raise an errir for static cache with FA2 enabled

* move  to the bottom of the signature

* style

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/llama/modeling_llama.py

* nit in the name

---------

Co-authored-by: Younes Belkada <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants